import numpy as np
import time
import torch
import torch.optim as optim
import cvxpy as cp
from scipy.optimize import minimize
#from numba import njit
from gurobipy import Model, GRB, QuadExpr, LinExpr

# Define the solve_weighting_vector function using cvxpy
def solve_weighting_vector_cvxpy(Kt, mu, PF, P):
    """
    Solve for the weighting vector λ* with CVXPY.

    Solves the simplex-constrained quadratic program

        min_λ  ||Kt @ (λ ⊙ P)||² - mu * (PF · λ)
        s.t.   sum(λ) = 1,  λ >= 0

    Parameters:
        Kt: Matrix of gradients (torch.Tensor); its column count must equal len(PF).
        mu: Trade-off parameter (float).
        PF: Preference vector times loss function (torch.Tensor).
        P: Weights multiplied element-wise into λ inside the norm (torch.Tensor).
    Returns:
        Optimal λ vector (torch.Tensor, float32).
    Raises:
        RuntimeError: if CVXPY returns no solution (e.g. solver failure).
    """
    PF_np = PF.cpu().detach().numpy()
    Kt_np = Kt.cpu().detach().numpy()
    P_np = P.cpu().detach().numpy()

    n = PF_np.shape[0]
    lambdas = cp.Variable(n)
    grad_sum = Kt_np @ cp.multiply(lambdas, cp.Constant(P_np))
    norm_squared = cp.sum_squares(grad_sum)
    objective = cp.Minimize(norm_squared - mu * (PF_np @ lambdas))
    constraints = [cp.sum(lambdas) == 1, lambdas >= 0]
    problem = cp.Problem(objective, constraints)
    problem.solve()

    # No solver is forced above, so report the one CVXPY actually selected
    # instead of assuming a specific backend.
    print(f"optimal value with {problem.solver_stats.solver_name}:", problem.value)

    # Without this check a failed solve surfaces as torch.tensor(None) -> TypeError.
    if lambdas.value is None:
        raise RuntimeError(
            f"CVXPY did not return a solution (status: {problem.status})"
        )
    return torch.tensor(lambdas.value, dtype=torch.float32)

def solve_weighting_vector_gurobi(Kt, mu, PF, P):
    """
    Solve for the weighting vector λ* with Gurobi.

    Solves the simplex-constrained quadratic program

        min_λ  ||Kt @ (λ ⊙ P)||² - mu * (PF · λ)
        s.t.   sum(λ) = 1,  λ >= 0

    Parameters:
        Kt: Matrix of gradients (torch.Tensor); its column count must equal len(PF).
        mu: Trade-off parameter (float).
        PF: Preference vector times loss function (torch.Tensor).
        P: Weights multiplied element-wise into λ inside the norm (torch.Tensor).
    Returns:
        Optimal λ vector (torch.Tensor, float32).
    """
    # Gurobi coefficients should be plain float64 values.
    PF_np = PF.cpu().detach().numpy().astype(np.float64)
    Kt_np = Kt.cpu().detach().numpy().astype(np.float64)
    P_np = P.cpu().detach().numpy().astype(np.float64)

    num_vars = len(PF_np)

    model = Model("weighting_vector_optimization")

    # λ_i variables with lower bound 0 (upper bound defaults to infinity).
    lambdas = model.addVars(num_vars, lb=0, name="lambda")
    lambda_list = [lambdas[i] for i in range(num_vars)]

    # (Kt @ (λ ⊙ P))[r] = Σ_j Kt[r, j] * P[j] * λ_j, so fold P into Kt's columns.
    weighted_Kt = Kt_np * P_np[np.newaxis, :]

    # ||g||² = Σ_r g_r², one squared linear expression per row of Kt.
    # (Summing g_i * g_j over all i, j would give (Σ g)², which is wrong.)
    norm_sq = QuadExpr()
    for row in weighted_Kt:
        g_r = LinExpr(row.tolist(), lambda_list)
        norm_sq += g_r * g_r

    linear_term = LinExpr(PF_np.tolist(), lambda_list)
    model.setObjective(norm_sq - mu * linear_term, GRB.MINIMIZE)

    # Simplex constraint: the weights sum to 1.
    model.addConstr(LinExpr([1.0] * num_vars, lambda_list) == 1)

    model.optimize()

    return torch.tensor([var.x for var in lambda_list], dtype=torch.float32)


def solve_weighting_vector_scipy(Kt, mu, PF, P):
    """
    Solve for the weighting vector λ* using SciPy's SLSQP minimizer.

    Minimizes ||Kt @ (λ ⊙ P)||² - mu * (PF · λ) subject to sum(λ) = 1 and
    λ >= 0, starting from the uniform distribution. Analytic gradients are
    supplied so SLSQP does not fall back to finite differences.

    Parameters:
        Kt: Matrix of gradients (torch.Tensor); its column count must equal len(PF).
        mu: Trade-off parameter (float).
        PF: Preference vector times loss function (torch.Tensor).
        P: Weights multiplied element-wise into λ inside the norm (torch.Tensor).
    Returns:
        Optimal λ vector (torch.Tensor, float32). ``result.success`` is not
        checked; the final iterate is returned either way.
    """
    # Work in float64 NumPy for SciPy.
    PF_np = PF.cpu().detach().numpy().astype(np.float64)
    Kt_np = Kt.cpu().detach().numpy().astype(np.float64)
    P_np = P.cpu().detach().numpy().astype(np.float64)
    n = len(PF_np)

    options = {
        # 'disp': True,  # Display optimization output
        'ftol': 1e-5,  # Function tolerance
    }

    def objective(lambdas):
        grad_sum = Kt_np @ (lambdas * P_np)
        return grad_sum @ grad_sum - mu * (lambdas @ PF_np)

    def objective_jac(lambdas):
        # ∇ ||Kt (λ⊙P)||² = 2 P ⊙ (Ktᵀ Kt (λ⊙P)); the linear term adds -mu * PF.
        grad_sum = Kt_np @ (lambdas * P_np)
        return 2.0 * P_np * (Kt_np.T @ grad_sum) - mu * PF_np

    # Constraint: sum of λs should be 1 (its gradient is all ones).
    constraints = [{
        'type': 'eq',
        'fun': lambda lambdas: np.sum(lambdas) - 1,
        'jac': lambda lambdas: np.ones_like(lambdas),
    }]

    # Bounds: λ >= 0
    bounds = [(0, None)] * n

    # Initial guess for λ: uniform distribution
    initial_guess = np.ones(n) / n

    result = minimize(
        objective,
        x0=initial_guess,
        jac=objective_jac,
        method='SLSQP',
        bounds=bounds,
        constraints=constraints,
        options=options,
    )

    return torch.tensor(result.x, dtype=torch.float32)

# def solve_weighting_vector_fast_jacobian(Kt, mu, PF, P, logger=None):
#     """
#     Solve for the weighting vector λ* using SciPy's minimize function.

#     Parameters:
#         Kt: Matrix of gradients (torch.Tensor).
#         u: Trade-off parameter (float).
#         PF: Preference vector times loss function (torch.Tensor).
#     Returns:
#         Optimal λ vector (torch.Tensor).
#     """

#     # Convert PF and Kt to NumPy for compatibility with SciPy
#     PF_np = PF.cpu().detach().numpy()
#     Kt_np = Kt.cpu().detach().numpy()
#     P_np = P.cpu().detach().numpy()

#     PF_np = PF_np.astype(np.float64)
#     Kt_np = Kt_np.astype(np.float64)
#     P_np = P_np.astype(np.float64)
    
#     options = {
#     #'disp': True,  # Display optimization output
#     'ftol': 1e-5,  # Function tolerance
#     }
#     @njit
#     def objective(lambdas):
#         # Compute the objective value using NumPy
#         grad_sum = Kt_np @ np.multiply(lambdas, P_np)
#         norm_grad_sum = np.linalg.norm(grad_sum, ord=2)
#         return norm_grad_sum ** 2 - mu* np.dot(lambdas, PF_np)
#         # Convert the objective function to be compatible with SciPy's minimize

#     def scipy_objective(lambdas):
#         return objective(np.array(lambdas))

#     # Constraint: sum of λs should be 1
#     constraints = [{'type': 'eq', 'fun': lambda lambdas: np.sum(lambdas) - 1}]

#     # Bounds: λ >= 0
#     bounds = [(0, None)] * len(PF_np)

#     # Initial guess for λ: uniform distribution
#     # Generate random values
#     random_vector = np.random.uniform(0, 1, len(PF_np))
#     # Normalize the vector to make the sum equal to 1
#     initial_guess = random_vector/ np.sum(random_vector)
#     #initial_guess = np.ones(len(PF_np)) / len(PF_np)

#     # Perform the optimization
#     result = minimize(scipy_objective, x0=initial_guess, bounds=bounds, constraints=constraints, method='SLSQP', options=options)
#     # print(result)
#     # Convert the result back to a PyTorch tensor
#     lambda_star = torch.tensor(result.x, dtype=torch.float32)
#     # print(
#         # f"Optimal λ vector: {lambda_star.numpy()}, Objective function value: {result.fun}, Convergence status: {'Converged' if result.success else 'Not converged'} ")
#     if logger:
#         logger.info(f"Optimal λ vector: {lambda_star.numpy()}, Objective function value: {result.fun}, Convergence status: {'Converged' if result.success else 'Not converged'} ")
#     print(result.fun)
#     return lambda_star

# def solve_weighting_vector_pytorch(Kt, mu, PF):
#     """
#     Solve for the weighting vector λ* using PyTorch optimizers.

#     Parameters:
#         Kt: Matrix of gradients (torch.Tensor).
#         mu: Trade-off parameter (float).
#         PF: Preference vector times loss function (torch.Tensor).
#     Returns:
#         Optimal λ vector (torch.Tensor).
#     """

#     # Convert PF and Kt to GPU if available
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     Kt = Kt.to(device)
#     PF = PF.to(device)

#     # Number of dimensions
#     n = PF.shape[0]

#     # Initialize λ with uniform distribution
#     lambdas = torch.ones(n, device=device) / n
#     lambdas.requires_grad = True

#     # Define optimizer
#     optimizer = optim.Adam([lambdas], lr=1e-2)  # Adam optimizer with a learning rate

#     # Define the loss function
#     def objective():
#         grad_sum = torch.matmul(Kt, lambdas)
#         norm_squared = torch.norm(grad_sum, p=2) ** 2
#         return mu * norm_squared - torch.matmul(lambdas, PF)

#     # Optimization loop
#     for _ in range(1000):  # Adjust the number of iterations as needed
#         optimizer.zero_grad()
#         loss = objective()
#         loss.backward()
#         optimizer.step()

#         # Ensure λ is non-negative and normalize to sum to 1
#         with torch.no_grad():
#             lambdas.data = torch.clamp(lambdas, min=0)  # Enforce non-negativity
#             lambdas.data = lambdas.data / lambdas.data.sum()  # Normalize to sum to 1

#     return lambdas.detach()

# Benchmarking function
def benchmark_solver(solver_func, Kt, mu, PF, name, P=None):
    """
    Time one call of a weighting-vector solver and print the elapsed time.

    Parameters:
        solver_func: Callable invoked as ``solver_func(Kt, mu, PF, P)``.
        Kt, mu, PF: Forwarded to the solver unchanged.
        name: Label used in the printed timing line.
        P: Weight vector forwarded to the solver. When omitted, ``PF`` is
           passed in its place (the previous, implicit behavior).
    Returns:
        Elapsed time of the solver call in seconds (float).
    """
    if P is None:
        P = PF
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    solver_func(Kt, mu, PF, P)
    elapsed_time = time.perf_counter() - start_time
    print(f"{name} Time: {elapsed_time:.4f} seconds")
    return elapsed_time


def matrix_sqrt(matrix):
    """
    Return a square root of a symmetric positive semi-definite matrix via SVD.

    With ``matrix = U diag(S) Vh`` this returns ``U diag(sqrt(S)) Vh``. For a
    symmetric PSD input, U and Vh.T coincide on the non-zero singular
    directions, so the result R is symmetric and ``R @ R ≈ matrix``. For other
    inputs the result is not guaranteed to be a square root.

    Parameters:
        matrix: Tensor of shape (..., n, n); leading batch dimensions are supported.
    Returns:
        Tensor of the same shape as ``matrix``.
    """
    U, S, Vh = torch.linalg.svd(matrix)
    # Scaling U's columns by sqrt(S) equals U @ diag(sqrt(S)) but, unlike
    # torch.diag, also broadcasts over batch dimensions.
    return (U * torch.sqrt(S).unsqueeze(-2)) @ Vh

def _example_gradients(device):
    """Return the 5x4 example gradient matrix shared by test cases 1 and 2."""
    return torch.tensor([[0.1, 0.2, 0.05, 1.2],
                         [6, 4, 20, 12],
                         [1.2, 2, 1, 5.4],
                         [0.5, 0.01, 0.1, 0.3],
                         [1, 0.1, 2, 0.4]], dtype=torch.float32).to(device)


def _report(label, lambda_star, expected, atol=1e-3):
    """Print whether a solver's λ matches the expected λ within ``atol``."""
    if torch.allclose(lambda_star, expected, atol=atol):
        print(f"{label} Passed: Basic Functionality")
    else:
        print(f"{label} Failed: Basic Functionality")


def _run_test_case_1(device):
    """Time each solver on a small problem and check it against an expected λ."""
    PF = torch.tensor([0.3, 0.3, 0.4, 0.4], dtype=torch.float32).to(device)
    P = torch.tensor([0.3, 0.3, 0.2, 0.2], dtype=torch.float32).to(device)
    Gt = _example_gradients(device)

    Kt = matrix_sqrt(torch.mm(Gt.T, Gt))
    print(Kt)
    mu = 1
    # Expected solution: (numerically) all weight on the second component.
    expected_lambda_star = torch.tensor(
        [-4.5833e-20, 1.0000e+00, -3.6783e-20, -1.0633e-21], dtype=torch.float32)

    solvers = (
        ("solve_weighting_vector_scipy", solve_weighting_vector_scipy, "Test scipy"),
        ("solve_weighting_vector_cvxpy", solve_weighting_vector_cvxpy, "Test 2"),
        ("solve_weighting_vector_gurobi", solve_weighting_vector_gurobi, "Test 3"),
    )
    for solver_name, solver, label in solvers:
        start_time = time.perf_counter()
        lambda_star = solver(Kt, mu, PF, P)
        print(f"  {solver_name} took {time.perf_counter() - start_time:.6f} seconds")
        print(lambda_star)
        _report(label, lambda_star, expected_lambda_star)

    # solve_weighting_vector_fast_jacobian is commented out in this file, so
    # there is no result to check; report that instead of re-checking gurobi.
    print("Test 4 Skipped: solve_weighting_vector_fast_jacobian is disabled")


def _run_test_case_2(device):
    """Run every solver on a P-scaled square-root kernel and print the results."""
    print("--------Test case 2--------")
    P = np.array([0.3, 0.3, 0.2, 0.2])
    P_tensor = torch.tensor(P, dtype=torch.float32).to(device)
    accumulated_losses = np.array([1.6, 1.1, 1.5, 0.5])
    Gt = _example_gradients(device)
    print(Gt.T.shape)
    print(Gt.shape)
    mu = 0.1

    G_T_G = torch.mm(Gt.T, Gt)
    sqrt_Gt_tG = matrix_sqrt(G_T_G)
    print("matrix is :..")
    print(sqrt_Gt_tG, torch.allclose(G_T_G, torch.mm(sqrt_Gt_tG, sqrt_Gt_tG)))

    # PF argument: element-wise product of preferences and accumulated losses.
    Hadamard_product_tensor = torch.tensor(
        P * accumulated_losses, dtype=torch.float32).to(device)
    sqrt_P_diag = torch.diag(torch.sqrt(P_tensor))
    Kt = sqrt_P_diag @ sqrt_Gt_tG @ sqrt_P_diag
    print("Kt:")
    print(Kt)

    for label, solver in (("weight scipy", solve_weighting_vector_scipy),
                          ("vector cvxpy", solve_weighting_vector_cvxpy),
                          ("vector gurobi", solve_weighting_vector_gurobi)):
        print(label)
        print(solver(Kt, mu, Hadamard_product_tensor, P_tensor))


def _run_benchmarks():
    """Time the cvxpy and SciPy solvers on a 1000-dimensional identity problem."""
    dimension = 1000
    Kt = torch.eye(dimension, dtype=torch.float32)
    PF = torch.ones(dimension, dtype=torch.float32)
    mu = 1.0

    print("Benchmarking solve_weighting_vector with cvxpy...")
    benchmark_solver(solve_weighting_vector_cvxpy, Kt, mu, PF, "cvxpy")

    print("Benchmarking solve_weighting_vector with scipy...")
    benchmark_solver(solve_weighting_vector_scipy, Kt, mu, PF, "scipy")
    # The pytorch and fast_jacobian solvers are commented out in this file,
    # so they are not benchmarked here.


def _verify_matrix_sqrt(device):
    """Check that sqrt(GᵀG)ᵀ · sqrt(GᵀG) reproduces GᵀG for a small matrix."""
    print("Verify sqrt of matrix transpose times matrix")
    Gt = torch.tensor([
        [3, 1, 1, 2],
        [2, 3, 4, 3],
        [2, 3, 1, 4],
        [1, 2, 3, 2],
        [3, 1, 1, 1]
    ], dtype=torch.float32).to(device)

    G_T_G = torch.mm(Gt.T, Gt)
    sqrt_Gt_tG = matrix_sqrt(G_T_G)
    S_T_S = torch.mm(sqrt_Gt_tG.T, sqrt_Gt_tG)
    print("Compare GT G with the sqrt matrix")
    print(torch.allclose(S_T_S, G_T_G))


def main():
    """Run the solver comparisons, the benchmarks and the square-root check."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _run_test_case_1(device)
    _run_test_case_2(device)
    _run_benchmarks()
    _verify_matrix_sqrt(device)


if __name__ == "__main__":
    main()